import asyncio
from collections import namedtuple
from pathlib import Path
import janus
import queue
import sys
import threading
import uuid

from .tracer import trace
from .utils import (
    detect_fts,
    detect_primary_keys,
    detect_spatialite,
    get_all_foreign_keys,
    get_outbound_foreign_keys,
    sqlite_timelimit,
    sqlite3,
    table_columns,
    table_column_details,
)
from .inspect import inspect_hash

# Thread-local storage for read connections: Database.execute_fn caches one
# connection per worker thread, stored as an attribute named after the database.
connections = threading.local()

# One row of ``PRAGMA database_list`` output (see Database.attached_databases).
AttachedDatabase = namedtuple("AttachedDatabase", "seq name file")


class Database:
    """One SQLite database (file-backed or in-memory) served through ``ds``.

    Reads run either synchronously on a lazily-opened connection (when
    ``ds.executor`` is None) or inside ``ds.executor`` on per-thread
    connections. Writes are serialized: in threaded mode every write is
    handed to a single dedicated writer thread that owns the only write
    connection.
    """

    def __init__(
        self, ds, path=None, is_mutable=True, is_memory=False, memory_name=None
    ):
        """Create a database wrapper.

        :param ds: owning object; its ``inspect_data``, ``executor``,
            ``nolock`` and ``_prepare_connection`` (among others) are used.
        :param path: filesystem path of the database file.
        :param is_mutable: False for files that never change; their hash and
            size are computed once, here.
        :param is_memory: use a private ``:memory:`` database.
        :param memory_name: name of a shared in-memory database; implies
            ``is_memory``.
        """
        self.name = None
        self.route = None
        self.ds = ds
        self.path = path
        self.is_mutable = is_mutable
        self.is_memory = is_memory
        self.memory_name = memory_name
        if memory_name is not None:
            # A named in-memory database is still an in-memory database
            self.is_memory = True
        self.hash = None
        self.cached_size = None
        self._cached_table_counts = None
        # Writer thread and its task queue, created on the first threaded write
        self._write_thread = None
        self._write_queue = None
        # These are used when in non-threaded mode:
        self._read_connection = None
        self._write_connection = None
        if not self.is_mutable and not self.is_memory:
            # NOTE(review): self.name is still None at this point, so this
            # lookup only matches inspect data stored under a None key;
            # otherwise hash and size are computed from the file below.
            if self.ds.inspect_data and self.ds.inspect_data.get(self.name):
                self.hash = self.ds.inspect_data[self.name]["hash"]
                self.cached_size = self.ds.inspect_data[self.name]["size"]
            else:
                p = Path(path)
                self.hash = inspect_hash(p)
                self.cached_size = p.stat().st_size

    @property
    def cached_table_counts(self):
        """Cached ``{table: row_count}`` for this database, or None if unknown.

        Filled lazily from ``ds.inspect_data`` when it has an entry for this
        database's name; ``table_counts`` also stores its result here for
        immutable databases.
        """
        if self._cached_table_counts is not None:
            return self._cached_table_counts
        # Maybe use self.ds.inspect_data to populate cached_table_counts
        if self.ds.inspect_data and self.ds.inspect_data.get(self.name):
            self._cached_table_counts = {
                key: value["count"]
                for key, value in self.ds.inspect_data[self.name]["tables"].items()
            }
        return self._cached_table_counts

    def suggest_name(self):
        """Return a default name: the file stem, else memory_name, else "db"."""
        if self.path:
            return Path(self.path).stem
        elif self.memory_name:
            return self.memory_name
        else:
            return "db"

    def connect(self, write=False):
        """Open and return a new sqlite3 connection to this database.

        - Named in-memory databases use a ``mode=memory&cache=shared`` URI;
          non-write connections are set to ``PRAGMA query_only=1``.
        - Anonymous in-memory databases get a fresh ``:memory:`` database.
        - Mutable files open with ``mode=ro`` (plus ``nolock=1`` when
          ``ds.nolock`` is set) unless ``write`` is true.
        - Immutable files open with ``immutable=1``; asking for a write
          connection to one is an assertion failure.
        """
        if self.memory_name:
            uri = "file:{}?mode=memory&cache=shared".format(self.memory_name)
            conn = sqlite3.connect(
                uri,
                uri=True,
                check_same_thread=False,
            )
            if not write:
                conn.execute("PRAGMA query_only=1")
            return conn
        if self.is_memory:
            return sqlite3.connect(":memory:", uri=True)
        # mode=ro or immutable=1?
        if self.is_mutable:
            qs = "?mode=ro"
            if self.ds.nolock:
                qs += "&nolock=1"
        else:
            qs = "?immutable=1"
        assert not (write and not self.is_mutable)
        if write:
            qs = ""
        return sqlite3.connect(
            f"file:{self.path}{qs}", uri=True, check_same_thread=False
        )

    async def execute_write(self, sql, params=None, block=True):
        """Run one write statement inside ``with conn:`` on the write connection.

        Returns the cursor, or the task id for a non-blocking threaded write.
        """

        def _inner(conn):
            with conn:
                return conn.execute(sql, params or [])

        with trace("sql", database=self.name, sql=sql.strip(), params=params):
            results = await self.execute_write_fn(_inner, block=block)
        return results

    async def execute_write_script(self, sql, block=True):
        """Run a multi-statement SQL script via ``executescript`` on the
        write connection.

        Returns the cursor, or the task id for a non-blocking threaded write.
        """

        def _inner(conn):
            with conn:
                return conn.executescript(sql)

        with trace("sql", database=self.name, sql=sql.strip(), executescript=True):
            results = await self.execute_write_fn(_inner, block=block)
        return results

    async def execute_write_many(self, sql, params_seq, block=True):
        """Run ``sql`` once per item of ``params_seq`` via ``executemany``.

        The number of parameter sets consumed is recorded on the trace as
        ``count`` when it is known. Returns the cursor, or the task id for a
        non-blocking threaded write (the count is not available yet then).
        """

        def _inner(conn):
            count = 0

            def count_params(params):
                # Count parameter sets lazily as executemany consumes them
                nonlocal count
                for param in params:
                    count += 1
                    yield param

            with conn:
                return conn.executemany(sql, count_params(params_seq)), count

        with trace(
            "sql", database=self.name, sql=sql.strip(), executemany=True
        ) as kwargs:
            outcome = await self.execute_write_fn(_inner, block=block)
            if isinstance(outcome, uuid.UUID):
                # Non-blocking threaded write: only the task id comes back, so
                # there is no (cursor, count) pair to unpack.
                return outcome
            results, count = outcome
            kwargs["count"] = count
        return results

    async def execute_write_fn(self, fn, block=True):
        """Call ``fn(write_connection)`` with writes serialized.

        Non-threaded mode (``ds.executor`` is None): ``fn`` runs immediately
        on a lazily-created write connection and its result is returned
        (``block`` is ignored).

        Threaded mode: ``fn`` is queued for the dedicated writer thread. With
        ``block=True`` the coroutine waits for the result and re-raises any
        exception ``fn`` raised; with ``block=False`` it returns a unique
        ``uuid.UUID`` task id immediately.
        """
        if self.ds.executor is None:
            # non-threaded mode
            if self._write_connection is None:
                self._write_connection = self.connect(write=True)
                self.ds._prepare_connection(self._write_connection, self.name)
            return fn(self._write_connection)

        # threaded mode
        # uuid4 is random, so each task gets a distinct id (uuid5 of a fixed
        # name would give every task the same id).
        task_id = uuid.uuid4()
        if self._write_queue is None:
            self._write_queue = queue.Queue()
        if self._write_thread is None:
            self._write_thread = threading.Thread(
                target=self._execute_writes, daemon=True
            )
            self._write_thread.start()
        reply_queue = janus.Queue()
        self._write_queue.put(WriteTask(fn, task_id, reply_queue))
        if block:
            result = await reply_queue.async_q.get()
            if isinstance(result, Exception):
                raise result
            else:
                return result
        else:
            return task_id

    def _execute_writes(self):
        """Writer-thread loop: run queued WriteTasks on the one write connection.

        Runs forever. If opening the write connection fails, every task is
        answered with that exception. Exceptions raised by a task are printed
        to stderr and sent back as that task's result.
        """
        # Infinite looping thread that protects the single write connection
        # to this database
        conn_exception = None
        conn = None
        try:
            conn = self.connect(write=True)
            self.ds._prepare_connection(conn, self.name)
        except Exception as e:
            conn_exception = e
        while True:
            task = self._write_queue.get()
            if conn_exception is not None:
                result = conn_exception
            else:
                try:
                    result = task.fn(conn)
                except Exception as e:
                    sys.stderr.write("{}\n".format(e))
                    sys.stderr.flush()
                    result = e
            task.reply_queue.sync_q.put(result)

    async def execute_fn(self, fn):
        """Call ``fn(read_connection)`` and return its result.

        Non-threaded mode uses one lazily-created read connection. Threaded
        mode runs ``fn`` in ``ds.executor`` using a read connection cached on
        the thread-local ``connections`` under this database's name.
        """
        if self.ds.executor is None:
            # non-threaded mode
            if self._read_connection is None:
                self._read_connection = self.connect()
                self.ds._prepare_connection(self._read_connection, self.name)
            return fn(self._read_connection)

        # threaded mode
        def in_thread():
            conn = getattr(connections, self.name, None)
            if not conn:
                conn = self.connect()
                self.ds._prepare_connection(conn, self.name)
                setattr(connections, self.name, conn)
            return fn(conn)

        return await asyncio.get_event_loop().run_in_executor(
            self.ds.executor, in_thread
        )

    async def execute(
        self,
        sql,
        params=None,
        truncate=False,
        custom_time_limit=None,
        page_size=None,
        log_sql_errors=True,
    ):
        """Execute ``sql`` on a read connection and return a ``Results``.

        :param params: bound parameters (defaults to an empty dict).
        :param truncate: when true and ``ds.max_returned_rows`` is set, fetch
            at most that many rows and flag ``Results.truncated`` if more
            were available.
        :param custom_time_limit: time limit in the same unit as
            ``ds.sql_time_limit_ms``; only used when it is lower.
        :param page_size: defaults to ``ds.page_size``; when it equals
            ``max_returned_rows``, one extra row is allowed.
        :param log_sql_errors: write failing SQL to stderr before re-raising.
        :raises QueryInterrupted: when SQLite reports the query was interrupted.
        """
        page_size = page_size or self.ds.page_size

        def sql_operation_in_thread(conn):
            time_limit_ms = self.ds.sql_time_limit_ms
            if custom_time_limit and custom_time_limit < time_limit_ms:
                time_limit_ms = custom_time_limit

            with sqlite_timelimit(conn, time_limit_ms):
                try:
                    cursor = conn.cursor()
                    cursor.execute(sql, params if params is not None else {})
                    max_returned_rows = self.ds.max_returned_rows
                    if max_returned_rows == page_size:
                        max_returned_rows += 1
                    if max_returned_rows and truncate:
                        # Fetch one extra row to detect whether more exist
                        rows = cursor.fetchmany(max_returned_rows + 1)
                        truncated = len(rows) > max_returned_rows
                        rows = rows[:max_returned_rows]
                    else:
                        rows = cursor.fetchall()
                        truncated = False
                except (sqlite3.OperationalError, sqlite3.DatabaseError) as e:
                    if e.args == ("interrupted",):
                        raise QueryInterrupted(e, sql, params)
                    if log_sql_errors:
                        sys.stderr.write(
                            "ERROR: conn={}, sql = {}, params = {}: {}\n".format(
                                conn, repr(sql), params, e
                            )
                        )
                        sys.stderr.flush()
                    raise

            if truncate:
                return Results(rows, truncated, cursor.description)

            else:
                return Results(rows, False, cursor.description)

        with trace("sql", database=self.name, sql=sql.strip(), params=params):
            results = await self.execute_fn(sql_operation_in_thread)
        return results

    @property
    def size(self):
        """Size in bytes: 0 for in-memory, the cached size, or a fresh stat."""
        if self.is_memory:
            return 0
        if self.cached_size is not None:
            return self.cached_size
        else:
            return Path(self.path).stat().st_size

    async def table_counts(self, limit=10):
        """Return ``{table_name: row_count}`` for every table.

        Each ``count(*)`` runs with ``limit`` as its custom time limit; tables
        whose count is interrupted or errors map to None. Immutable databases
        return, and store, cached counts.
        """
        if not self.is_mutable and self.cached_table_counts is not None:
            return self.cached_table_counts
        # Try to get counts for each table, $limit timeout for each count
        counts = {}
        for table in await self.table_names():
            # Quote as an SQL identifier, doubling embedded double quotes, so
            # names containing "]" or '"' still produce valid SQL.
            quoted_table = '"{}"'.format(table.replace('"', '""'))
            try:
                table_count = (
                    await self.execute(
                        f"select count(*) from {quoted_table}",
                        custom_time_limit=limit,
                    )
                ).rows[0][0]
                counts[table] = table_count
            # In some cases I saw "SQL Logic Error" here in addition to
            # QueryInterrupted - so we catch that too:
            except (QueryInterrupted, sqlite3.OperationalError, sqlite3.DatabaseError):
                counts[table] = None
        if not self.is_mutable:
            self._cached_table_counts = counts
        return counts

    @property
    def mtime_ns(self):
        """File modification time in nanoseconds, or None for in-memory."""
        if self.is_memory:
            return None
        return Path(self.path).stat().st_mtime_ns

    async def attached_databases(self):
        """Return ``AttachedDatabase`` rows for every database except main."""
        # This used to be:
        #   select seq, name, file from pragma_database_list() where seq > 0
        # But SQLite prior to 3.16.0 doesn't support pragma functions
        results = await self.execute("PRAGMA database_list;")
        # {'seq': 0, 'name': 'main', 'file': ''}
        return [AttachedDatabase(*row) for row in results.rows if row["seq"] > 0]

    async def table_exists(self, table):
        """Return True if a table named ``table`` exists."""
        results = await self.execute(
            "select 1 from sqlite_master where type='table' and name=?", params=(table,)
        )
        return bool(results.rows)

    async def table_names(self):
        """Return the names of all tables in ``sqlite_master``."""
        results = await self.execute(
            "select name from sqlite_master where type='table'"
        )
        return [r[0] for r in results.rows]

    async def table_columns(self, table):
        """Return the column names of ``table``."""
        return await self.execute_fn(lambda conn: table_columns(conn, table))

    async def table_column_details(self, table):
        """Return detailed column information for ``table``."""
        return await self.execute_fn(lambda conn: table_column_details(conn, table))

    async def primary_keys(self, table):
        """Return the primary key columns detected for ``table``."""
        return await self.execute_fn(lambda conn: detect_primary_keys(conn, table))

    async def fts_table(self, table):
        """Return the FTS table detected for ``table``, if any."""
        return await self.execute_fn(lambda conn: detect_fts(conn, table))

    async def label_column_for_table(self, table):
        """Pick the column used as a human-readable label for rows of ``table``.

        Order of preference: an explicit ``label_column`` in table metadata;
        a column named "name" or "title" (case-insensitive); the other column
        of a two-column table that has an "id" or "pk" column. Returns None
        when nothing applies.
        """
        explicit_label_column = self.ds.table_metadata(self.name, table).get(
            "label_column"
        )
        if explicit_label_column:
            return explicit_label_column
        column_names = await self.execute_fn(lambda conn: table_columns(conn, table))
        # Is there a name or title column?
        name_or_title = [c for c in column_names if c.lower() in ("name", "title")]
        if name_or_title:
            return name_or_title[0]
        # If a table has two columns, one of which is ID, then label_column is the other one
        if (
            column_names
            and len(column_names) == 2
            and ("id" in column_names or "pk" in column_names)
        ):
            others = [c for c in column_names if c not in ("id", "pk")]
            # Columns exactly ("id", "pk") leave no candidate
            if others:
                return others[0]
        # Couldn't find a label:
        return None

    async def foreign_keys_for_table(self, table):
        """Return the outbound foreign keys of ``table``."""
        return await self.execute_fn(
            lambda conn: get_outbound_foreign_keys(conn, table)
        )

    async def hidden_table_names(self):
        """Return the names of tables that should be hidden, without duplicates.

        Hidden tables are FTS-related and ``sqlite_stat*`` tables, Spatialite
        internals (when Spatialite is detected), tables marked ``hidden`` in
        metadata, and any table whose name starts with one of those names.
        Order is first-seen.
        """
        # Mark tables 'hidden' if they relate to FTS virtual tables
        hidden_tables = [
            r[0]
            for r in (
                await self.execute(
                    """
                select name from sqlite_master
                where rootpage = 0
                and (
                    sql like '%VIRTUAL TABLE%USING FTS%'
                ) or name in ('sqlite_stat1', 'sqlite_stat2', 'sqlite_stat3', 'sqlite_stat4')
            """
                )
            ).rows
        ]
        has_spatialite = await self.execute_fn(detect_spatialite)
        if has_spatialite:
            # Also hide Spatialite internal tables
            hidden_tables += [
                "ElementaryGeometries",
                "SpatialIndex",
                "geometry_columns",
                "spatial_ref_sys",
                "spatialite_history",
                "sql_statements_log",
                "sqlite_sequence",
                "views_geometry_columns",
                "virts_geometry_columns",
                "data_licenses",
                "KNN",
                "KNN2",
            ] + [
                r[0]
                for r in (
                    # Single-quoted string literals: double quotes are
                    # identifiers and only work as strings via a legacy
                    # SQLite fallback that builds may disable.
                    await self.execute(
                        """
                        select name from sqlite_master
                        where name like 'idx_%'
                        and type = 'table'
                    """
                    )
                ).rows
            ]
        # Add any from metadata.json
        db_metadata = self.ds.metadata(database=self.name)
        if "tables" in db_metadata:
            hidden_tables += [
                t
                for t in db_metadata["tables"]
                if db_metadata["tables"][t].get("hidden")
            ]
        # Also mark as hidden any tables which start with the name of a hidden table
        # e.g. "searchable_fts" implies "searchable_fts_content" should be hidden
        prefixes = tuple(hidden_tables)
        for table_name in await self.table_names():
            if table_name.startswith(prefixes):
                hidden_tables.append(table_name)
        # A table can match several prefixes (including its own name), so
        # drop duplicates while keeping first-seen order.
        return list(dict.fromkeys(hidden_tables))

    async def view_names(self):
        """Return the names of all views in ``sqlite_master``."""
        results = await self.execute("select name from sqlite_master where type='view'")
        return [r[0] for r in results.rows]

    async def get_all_foreign_keys(self):
        """Return foreign key information for all tables."""
        return await self.execute_fn(get_all_foreign_keys)

    async def get_table_definition(self, table, type_="table"):
        """Return the CREATE SQL for ``table`` plus its explicit indexes.

        Each statement is terminated with ";" and statements are joined by
        newlines. Returns None if no object of type ``type_`` has that name.
        """
        table_definition_rows = list(
            await self.execute(
                "select sql from sqlite_master where name = :n and type=:t",
                {"n": table, "t": type_},
            )
        )
        if not table_definition_rows:
            return None
        bits = [table_definition_rows[0][0] + ";"]
        # Add on any indexes
        index_rows = list(
            await self.execute(
                "select sql from sqlite_master where tbl_name = :n and type='index' and sql is not null",
                {"n": table},
            )
        )
        for index_row in index_rows:
            bits.append(index_row[0] + ";")
        return "\n".join(bits)

    async def get_view_definition(self, view):
        """Return the CREATE VIEW SQL for ``view``, or None if it does not exist."""
        return await self.get_table_definition(view, "view")

    def __repr__(self):
        tags = []
        if self.is_mutable:
            tags.append("mutable")
        if self.is_memory:
            tags.append("memory")
        if self.hash:
            tags.append(f"hash={self.hash}")
        # Note: for file databases without a cached size this stats the file
        if self.size is not None:
            tags.append(f"size={self.size}")
        tags_str = ""
        if tags:
            tags_str = f" ({', '.join(tags)})"
        return f"<Database: {self.name}{tags_str}>"


class WriteTask:
    """A unit of work queued for a database's writer thread.

    ``fn`` is called with the write connection; its return value (or the
    exception it raised) is delivered back through ``reply_queue``.
    ``task_id`` identifies the task.
    """

    __slots__ = ("fn", "task_id", "reply_queue")

    def __init__(self, fn, task_id, reply_queue):
        self.fn, self.task_id, self.reply_queue = fn, task_id, reply_queue


class QueryInterrupted(Exception):
    """Raised by ``Database.execute`` when SQLite reports the query was
    interrupted.

    Attributes:
        e: the underlying sqlite3 exception.
        sql: the SQL that was being executed.
        params: the parameters bound to that SQL.
    """

    def __init__(self, e, sql, params):
        # Pass the arguments through to Exception so ``args``/``str()`` are
        # populated and instances survive pickling/copying (reconstruction
        # calls ``cls(*self.args)``, which failed with empty args).
        super().__init__(e, sql, params)
        self.e = e
        self.sql = sql
        self.params = params


class MultipleValues(Exception):
    """Raised by ``Results.single_value()`` when the result is not exactly
    one row containing exactly one column."""


class Results:
    """Rows returned by a query, with a truncation flag and cursor description.

    Iterating over a ``Results`` yields its rows; ``len()`` is the row count.
    """

    def __init__(self, rows, truncated, description):
        self.rows = rows
        # True when the rows were cut short of what the query could return
        self.truncated = truncated
        # DB-API ``cursor.description``; None for statements that return no rows
        self.description = description

    @property
    def columns(self):
        """Column names in result order; empty when there is no description."""
        if self.description is None:
            return []
        return [d[0] for d in self.description]

    def first(self):
        """Return the first row, or None if there are no rows."""
        if self.rows:
            return self.rows[0]
        else:
            return None

    def single_value(self):
        """Return the only value of a one-row, one-column result.

        :raises MultipleValues: if the result is not exactly one row with
            exactly one column (including when it is empty).
        """
        if self.rows and 1 == len(self.rows) and 1 == len(self.rows[0]):
            return self.rows[0][0]
        else:
            raise MultipleValues

    def __iter__(self):
        return iter(self.rows)

    def __len__(self):
        return len(self.rows)
